nlp_architect.models.gnmt_model.GNMTModel

class nlp_architect.models.gnmt_model.GNMTModel(hparams, mode, iterator, source_vocab_table, target_vocab_table, reverse_target_vocab_table=None, scope=None, extra_args=None)

Sequence-to-sequence dynamic model using the GNMT attention architecture, with support for sparsity policies.

__init__(hparams, mode, iterator, source_vocab_table, target_vocab_table, reverse_target_vocab_table=None, scope=None, extra_args=None)

Create the model.

Parameters
  • hparams – Hyperparameter configurations.

  • mode – One of TRAIN, EVAL, or INFER.

  • iterator – Dataset Iterator that feeds data.

  • source_vocab_table – Lookup table mapping source words to ids.

  • target_vocab_table – Lookup table mapping target words to ids.

  • reverse_target_vocab_table – Lookup table mapping ids to target words. Only required in INFER mode. Defaults to None.

  • scope – Variable scope of the model.

  • extra_args – model_helper.ExtraArgs, for passing customizable functions.
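The roles of the three vocabulary tables, and the rule that reverse_target_vocab_table is only needed in INFER mode, can be illustrated with plain Python dictionaries. This is a minimal sketch, not the library's code: the real model receives TensorFlow lookup tables, and the names `Mode`, `make_tables`, and `check_args`, as well as the `<unk>` fallback id of 0, are hypothetical.

```python
from enum import Enum


class Mode(Enum):
    # Hypothetical stand-ins for the TRAIN | EVAL | INFER modes.
    TRAIN = "train"
    EVAL = "eval"
    INFER = "infer"


def make_tables(vocab, unk_id=0):
    """Build a word->id lookup (with an <unk> fallback) and its id->word reverse."""
    word_to_id = {word: i for i, word in enumerate(vocab)}
    id_to_word = {i: word for i, word in enumerate(vocab)}

    def lookup(words):
        # Out-of-vocabulary words map to unk_id, like a lookup table's default value.
        return [word_to_id.get(w, unk_id) for w in words]

    return lookup, id_to_word


def check_args(mode, reverse_target_vocab_table):
    """Hypothetical check: the reverse table must be supplied in INFER mode."""
    if mode is Mode.INFER and reverse_target_vocab_table is None:
        raise ValueError("reverse_target_vocab_table is required in INFER mode")


lookup, reverse = make_tables(["<unk>", "<s>", "</s>", "hello", "world"])
print(lookup(["hello", "there", "world"]))  # [3, 0, 4]
check_args(Mode.TRAIN, None)  # fine: the reverse table is not needed outside INFER
```

The reverse table only matters when ids produced by the decoder must be turned back into words, which is why TRAIN and EVAL can run without it.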

Methods

__init__(hparams, mode, iterator, …[, …])

Create the model.

build_encoder_states([include_embeddings])

Stack encoder states and return tensor [batch, length, layer, size].

build_graph(hparams[, scope])

Subclass must implement this method.

decode(sess)

Decode a batch.

eval(sess)

Execute eval graph.

get_max_time(tensor)

infer(sess)

init_embeddings(hparams, scope)

Init embeddings.

train(sess)

Execute train graph.

build_encoder_states(include_embeddings=False)

Stack encoder states and return tensor [batch, length, layer, size].
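The stacking described above can be sketched in pure Python with nested lists. This is an illustration under stated assumptions, not the method itself: it assumes each layer's output is shaped [batch, length, size] (or [length, batch, size] when time-major), and that including embeddings places the embedded encoder inputs in front of the layer outputs as an extra "layer". The real method takes a boolean include_embeddings and works on TensorFlow tensors; the signature below is hypothetical.

```python
def build_encoder_states(layer_outputs, embeddings=None, time_major=False):
    """Stack per-layer encoder outputs into [batch][length][layer][size].

    layer_outputs: list (one entry per layer) of nested lists shaped
        [batch][length][size], or [length][batch][size] if time_major.
    embeddings: optional encoder input embeddings with the same shape;
        when given, they become layer 0 of the stacked result.
    """
    states = ([embeddings] if embeddings is not None else []) + list(layer_outputs)
    if time_major:
        # Swap the first two axes: [length][batch][size] -> [batch][length][size].
        states = [[list(col) for col in zip(*s)] for s in states]
    batch = len(states[0])
    length = len(states[0][0])
    # result[b][t][l] is the size-dim vector of layer l at position t of sequence b.
    return [[[s[b][t] for s in states] for t in range(length)] for b in range(batch)]


layer0 = [[[1, 2], [3, 4]]]  # batch=1, length=2, size=2
layer1 = [[[5, 6], [7, 8]]]
stacked = build_encoder_states([layer0, layer1])
print(stacked)  # [[[[1, 2], [5, 6]], [[3, 4], [7, 8]]]]
```

Putting the layer axis third keeps all layers' states for one source position together, which is convenient when each position's multi-layer representation is consumed as a unit.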

build_graph(hparams, scope=None)

Subclass must implement this method.

Creates a sequence-to-sequence model with the dynamic RNN decoder API.

Parameters
  • hparams – Hyperparameter configurations.

  • scope – VariableScope for the created subgraph; defaults to “dynamic_seq2seq”.

Returns

A tuple of the form (logits, loss_tuple, final_context_state, sample_id), where:

  • logits – float32 Tensor [batch_size x num_decoder_symbols].

  • loss_tuple – the loss, computed as the total loss divided by batch_size.

  • final_context_state – the final state of the decoder RNN.

  • sample_id – sampling indices.
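The loss convention above (total loss divided by batch_size, not by the number of target tokens) can be sketched in plain Python. This assumes, as is usual for padded sequence batches, that per-token losses at padding positions are masked out using each target's length before summing; the function name and its inputs are hypothetical.

```python
def sequence_loss(token_losses, target_lengths):
    """Sum per-token losses over real (unpadded) positions, divided by batch size.

    token_losses: [batch][max_time] per-token cross-entropy values.
    target_lengths: number of real (non-padding) tokens in each target.
    """
    batch_size = len(token_losses)
    total = 0.0
    for losses, length in zip(token_losses, target_lengths):
        # Positions at or beyond the target length are padding and contribute nothing.
        total += sum(losses[:length])
    return total / batch_size


token_losses = [[1.0, 2.0, 3.0], [0.5, 0.5, 9.0]]  # 9.0 sits at a padding position
print(sequence_loss(token_losses, [3, 2]))  # 3.5
```

Because the sum is normalized per sentence rather than per token, longer target sentences contribute proportionally more to the loss.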

Raises

ValueError – if encoder_type is neither mono nor bi, or if attention_option is not one of luong, scaled_luong, bahdanau, or normed_bahdanau.
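The conditions in the Raises entry amount to a simple membership check on two hyperparameters. The sketch below is based only on that entry; the constant names, the function name, and the error messages are hypothetical and not taken from the library.

```python
VALID_ENCODER_TYPES = ("mono", "bi")
VALID_ATTENTION_OPTIONS = ("luong", "scaled_luong", "bahdanau", "normed_bahdanau")


def validate_hparams(encoder_type, attention_option):
    """Raise ValueError for the option values listed under Raises."""
    if encoder_type not in VALID_ENCODER_TYPES:
        raise ValueError("Unknown encoder_type %s" % encoder_type)
    if attention_option not in VALID_ATTENTION_OPTIONS:
        raise ValueError("Unknown attention_option %s" % attention_option)


validate_hparams("bi", "normed_bahdanau")  # accepted: no exception
```

Failing fast on an unknown option avoids building part of a graph before discovering a misspelled hyperparameter.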

decode(sess)

Decode a batch.

Parameters

sess – TensorFlow session to use.

Returns

A tuple consisting of outputs and infer_summary.

outputs: of shape [batch_size, time]
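How a [batch_size, time] output relates to readable sentences can be sketched in pure Python. This assumes the decoder produces integer ids that are mapped back through an id-to-word table like reverse_target_vocab_table, and that each sequence is cut at an end-of-sentence token `</s>`; both the EOS symbol and the helper name are hypothetical. The time_major flag shows the transpose that turns [time, batch] decoder output into the [batch_size, time] layout.

```python
def ids_to_sentences(sample_ids, reverse_vocab, eos="</s>", time_major=False):
    """Convert decoded ids to token lists, truncating each at the first EOS.

    sample_ids: [batch][time] ids, or [time][batch] if time_major.
    reverse_vocab: mapping from id to target word.
    """
    if time_major:
        # Transpose [time][batch] -> [batch][time].
        sample_ids = [list(row) for row in zip(*sample_ids)]
    sentences = []
    for ids in sample_ids:
        words = []
        for i in ids:
            word = reverse_vocab[i]
            if word == eos:
                break  # everything after EOS is padding or garbage
            words.append(word)
        sentences.append(words)
    return sentences


rv = {0: "<unk>", 1: "<s>", 2: "</s>", 3: "hello", 4: "world"}
print(ids_to_sentences([[3, 4, 2, 2], [4, 2, 0, 0]], rv))  # [['hello', 'world'], ['world']]
```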

eval(sess)

Execute eval graph.

get_max_time(tensor)

infer(sess)
init_embeddings(hparams, scope)

Init embeddings.

train(sess)

Execute train graph.